/*******************************************************************************

  Pilot Intelligence Library
    http://www.pilotintelligence.com/

  ----------------------------------------------------------------------------

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.

*******************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <Poco/Net/StreamSocket.h>
#include <Poco/Net/DatagramSocket.h>
#include <Poco/Net/Socket.h>
#include <Poco/Net/SocketAddress.h>
#include <Poco/Net/ServerSocket.h>
#include <Poco/Net/SocketStream.h>
#include <Poco/Net/TCPServer.h>
#include <Poco/Net/MulticastSocket.h>
#include <Poco/Net/NetworkInterface.h>
#include <Poco/Net/IPAddress.h>

#include "base/debug/debug_config.h"
#include "Socket++_poco.h"


namespace pi {


////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

/// Poco-backed implementation of pi::RSocket.
///
/// One object holds four Poco sockets (TCP server, TCP stream, UDP datagram,
/// UDP multicast); m_socketType selects which one each call operates on.
/// m_isServer is 0 for a client, 1 for a socket started with startServer()
/// and 2 for a connection filled in by accept().
///
/// Failures are reported through negative return codes; exceptions thrown by
/// the Poco socket calls are caught and not propagated.
class RSocketImpl
{
public:
    RSocketImpl() {
        m_isOpened = 0;
        m_isServer = 0;
        m_socketType = SOCKET_TCP;
    }

    /// Closes the socket if it is still open.
    virtual ~RSocketImpl() {
        close();
    }


    ////////////////////////////////////////////////////
    /// Data Transmission
    ////////////////////////////////////////////////////

    /// Send `len` bytes from `dat` on the active socket.
    /// @return bytes sent, -1 if Poco threw, -2 for an unknown socket type.
    int send(const void* dat, int len) {
        try {
            if( m_socketType == SOCKET_TCP ) {
                return socketTCP.sendBytes(dat, len);
            } else if( m_socketType == SOCKET_UDP ) {
                return socketUDP.sendBytes(dat, len);
            } else if( m_socketType == SOCKET_UDP_MULTICAST ) {
                return socketUDP_mc.sendBytes(dat, len);
            }
        } catch( ... ) {
            return -1;
        }

        return -2;
    }

    /// Send a string including its terminating NUL (size() + 1 bytes).
    /// The recv(std::string&) overload below stops at the first NUL byte.
    int send(const std::string& msg) {
        return send(msg.c_str(), static_cast<int>(msg.size() + 1));
    }


    /// Receive up to `len` bytes into `dat`.
    /// For UDP/multicast `sender` receives the datagram's source address.
    /// TCP's receiveBytes() reports no source, so `sender` is then 0.0.0.0:0.
    /// @return bytes received, -1 if Poco threw, -2 for an unknown socket type.
    int recv(void* dat, int len, RSocketAddress* sender=NULL) {
        int n = -2;
        Poco::Net::SocketAddress sa("0.0.0.0", 0);

        try {
            if( m_socketType == SOCKET_TCP ) {
                n = socketTCP.receiveBytes(dat, len);
            } else if( m_socketType == SOCKET_UDP ) {
                n = socketUDP.receiveFrom(dat, len, sa);
            } else if( m_socketType == SOCKET_UDP_MULTICAST ) {
                n = socketUDP_mc.receiveFrom(dat, len, sa);
            }
        } catch( ... ) {
            return -1;
        }

        // report the peer address
        if( sender != NULL ) {
            sender->address = sa.host().toString();
            sender->port = sa.port();
            sender->type = m_socketType;
        }

        return n;
    }

    /// Receive at most `maxLen` bytes and store them in `msg` as a C string
    /// (the copy stops at the first NUL byte). `msg` is cleared first.
    /// @return the status returned by recv(), or -1 if `maxLen` is negative.
    int recv(std::string& msg, int maxLen = 4096, RSocketAddress* sender=NULL) {
        msg = "";
        if( maxLen < 0 ) return -1;

        // one extra zero byte guarantees the buffer is NUL-terminated
        std::string buf(static_cast<std::string::size_type>(maxLen) + 1, '\0');

        int status = recv(&buf[0], maxLen, sender);
        if( status > 0 ) msg = buf.c_str();

        return status;
    }

    /// Fill `dat` with `len` bytes by calling recv() repeatedly.
    /// On TCP a zero-byte read means the peer closed the connection. The loop
    /// then stops and returns the short count instead of spinning forever.
    /// @return bytes received (== len on success, < len if the TCP peer closed
    ///         early; 0 when len <= 0), or the negative code from recv().
    int recvUntil(void *dat, int len, RSocketAddress* sender=NULL) {
        uint8_t *p = static_cast<uint8_t*>(dat);
        int received = 0;

        while( received < len ) {
            int ret = recv(p + received, len - received, sender);
            if( ret < 0 ) return ret;
            if( ret == 0 && m_socketType == SOCKET_TCP ) break;

            received += ret;
        }

        return received;
    }

    /// Poll the active socket. `timeout` and `mode` are passed straight to
    /// Poco's Socket::poll. NOTE(review): timeout is presumably microseconds
    /// (Poco::Timespan) and mode Poco's SELECT_* flags. Confirm.
    /// @return Poco's poll result as int, -1 if Poco threw, -2 for an unknown
    ///         socket type.
    int pool(int64_t timeout, int mode) {
        try {
            if( m_socketType == SOCKET_TCP ) {
                return socketTCP.poll(timeout, mode);
            } else if( m_socketType == SOCKET_UDP ) {
                return socketUDP.poll(timeout, mode);
            } else if( m_socketType == SOCKET_UDP_MULTICAST ) {
                return socketUDP_mc.poll(timeout, mode);
            }
        } catch( ... ) {
            return -1;
        }

        // unknown socket type: return a defined value instead of falling off
        // the end of the function
        return -2;
    }

    /// @return Poco's available() for the active socket, -1 if Poco threw,
    ///         -2 for an unknown socket type.
    int available(void) {
        try {
            if( m_socketType == SOCKET_TCP ) {
                return socketTCP.available();
            } else if( m_socketType == SOCKET_UDP ) {
                return socketUDP.available();
            } else if( m_socketType == SOCKET_UDP_MULTICAST ) {
                return socketUDP_mc.available();
            }
        } catch( ... ) {
            return -1;
        }

        return -2;
    }


    ////////////////////////////////////////////////////
    /// socket creation & close
    ////////////////////////////////////////////////////

    /// Mark the socket as opened.
    /// @return 0, or -1 if it was already open.
    int create(void) {
        if( m_isOpened ) return -1;

        m_isOpened = 1;
        return 0;
    }

    /// Close the active socket and reset the object to its initial state.
    /// The state is reset even if Poco throws, so the object stays reusable.
    /// @return 0 on success, -1 if the socket was not open, -2 if Poco threw.
    int close(void) {
        int ret = -1;

        if( m_isOpened ) {
            ret = 0;
            try {
                if( m_socketType == SOCKET_TCP ) {
                    if( m_isServer == 1 )
                        socketTCP_server.close();
                    else
                        socketTCP.close();
                } else if( m_socketType == SOCKET_UDP ) {
                    socketUDP.close();
                } else if( m_socketType == SOCKET_UDP_MULTICAST ) {
                    socketUDP_mc.close();
                }
            } catch( ... ) {
                ret = -2;
            }
        }

        m_socketType = SOCKET_TCP;
        m_isOpened = 0;
        m_isServer = 0;

        return ret;
    }


    ////////////////////////////////////////////////////
    /// start server or client
    ////////////////////////////////////////////////////

    /// Start a TCP server (bind + listen) or UDP server (bind) on `port`.
    /// Multicast needs a group address; use the other overload.
    /// @return 0; -1 already open; -2 bind failed; -3 listen failed;
    ///         -4 multicast requested without a group address.
    int startServer(int port, RSocketType t=SOCKET_TCP) {
        // Check first so that an already-open socket keeps its type/role.
        if( 0 != create() ) return -1;

        m_socketType = t;
        m_isServer = 1;

        int ret;
        if( m_socketType == SOCKET_TCP || m_socketType == SOCKET_UDP ) {
            ret = bindAndListen(port);
        } else {
            dbg_pe("please specific group address!\n");
            ret = -4;
        }

        // release whatever bind() created and reset the state
        if( ret != 0 ) close();

        return ret;
    }

    /// Start a server on `port`. For SOCKET_UDP_MULTICAST, also join the
    /// multicast group `addr`. `addr` is ignored for TCP/UDP.
    /// @return 0; -1 already open; -2 bind failed; -3 listen failed;
    ///         -4 joining the multicast group failed.
    int startServer(const std::string& addr, int port,
                    RSocketType t=SOCKET_UDP_MULTICAST) {
        if( 0 != create() ) return -1;

        m_socketType = t;
        m_isServer = 1;

        int ret = 0;
        if( m_socketType == SOCKET_TCP || m_socketType == SOCKET_UDP ||
            m_socketType == SOCKET_UDP_MULTICAST ) {
            ret = bindAndListen(port);
        }

        if( ret == 0 && m_socketType == SOCKET_UDP_MULTICAST ) {
            try {
                Poco::Net::IPAddress ip(addr);
                socketUDP_mc.joinGroup(ip);
            } catch( ... ) {
                ret = -4;
            }
        }

        if( ret != 0 ) close();

        return ret;
    }

    /// Open the socket as a client and connect it to `host`:`port`.
    /// @return 0; -1 already open; -2 connect failed (socket closed again).
    int startClient(const std::string& host, int port,
                    RSocketType t=SOCKET_TCP) {
        if( 0 != create() ) return -1;

        m_socketType = t;
        m_isServer = 0;

        if( 0 != connect(host, port) ) {
            // discard the failed socket so a retry starts from a fresh one
            close();
            return -2;
        }

        return 0;
    }


    ////////////////////////////////////////////////////
    /// server functions
    ////////////////////////////////////////////////////

    /// Bind the active socket to `port` on the wildcard address.
    /// @return 0, or -1 if not open or Poco threw.
    int bind(int port) {
        if( !isOpened() ) return -1;

        try {
            Poco::Net::SocketAddress sa(Poco::Net::IPAddress(), port);

            if( m_socketType == SOCKET_TCP ) {
                socketTCP_server.bind(port);
            } else if( m_socketType == SOCKET_UDP ) {
                socketUDP.bind(sa);
            } else if( m_socketType == SOCKET_UDP_MULTICAST ) {
                socketUDP_mc.bind(sa);
            }

            return 0;
        } catch( ... ) {
            return -1;
        }
    }

    /// Start listening on the TCP server socket.
    /// @return 0, or -1 if not open, not TCP, or Poco threw.
    int listen(void) {
        if( !isOpened() ) return -1;

        try {
            if( m_socketType == SOCKET_TCP ) {
                socketTCP_server.listen();
                return 0;
            } else {
                return -1;
            }
        } catch( ... ) {
            return -1;
        }
    }

    /// Accept one TCP connection and install it into `s` (role 2).
    /// `s` is modified only after acceptConnection() succeeds.
    /// @return 0, or -1 if not open, not TCP, or Poco threw.
    int accept(RSocket& s) {
        if( !isOpened() ) return -1;
        if( m_socketType != SOCKET_TCP ) return -1;

        try {
            Poco::Net::StreamSocket conn = socketTCP_server.acceptConnection();

            RSocketImpl *impl = s.m_impl.get();

            impl->m_isOpened = 1;
            impl->m_isServer = 2;
            impl->m_socketType = SOCKET_TCP;
            impl->socketTCP_server = this->socketTCP_server;
            impl->socketTCP = conn;

            return 0;
        } catch( ... ) {
            return -1;
        }
    }


    ////////////////////////////////////////////////////
    /// client functions
    ////////////////////////////////////////////////////

    /// Connect the active socket to `host`:`port`.
    /// @return 0, or -1 if not open or Poco threw.
    int connect(const std::string& host, int port) {
        if( !isOpened() ) return -1;

        try {
            Poco::Net::SocketAddress sa(host, port);

            if( m_socketType == SOCKET_TCP ) {
                socketTCP.connect(sa);
            } else if( m_socketType == SOCKET_UDP ) {
                socketUDP.connect(sa);
            } else if( m_socketType == SOCKET_UDP_MULTICAST ) {
                socketUDP_mc.connect(sa);
            }
        } catch( ... ) {
            return -1;
        }

        return 0;
    }


    ////////////////////////////////////////////////////
    /// get address
    ////////////////////////////////////////////////////

    /// Fill `a` with the local address of the active socket.
    /// @return 0, or -1 if not open or Poco threw.
    int getMyAddress(RSocketAddress &a) {
        if( !isOpened() ) return -1;

        Poco::Net::SocketAddress addr;

        try {
            if( m_socketType == SOCKET_TCP ) {
                if( m_isServer == 1 ) {
                    addr = socketTCP_server.address();
                } else {
                    addr = socketTCP.address();
                }
            } else if( m_socketType == SOCKET_UDP ) {
                addr = socketUDP.address();
            } else if( m_socketType == SOCKET_UDP_MULTICAST ) {
                addr = socketUDP_mc.address();
            }
        } catch( ... ) {
            return -1;
        }

        a.address = addr.host().toString();
        a.port = addr.port();
        a.type = m_socketType;

        return 0;
    }

    /// Fill `a` with the client's address. On a connection produced by
    /// accept() that is the remote peer. Otherwise it is the local endpoint.
    /// @return 0, or -1 if not open or Poco threw.
    // FIXME: need to change to getPeerAddress
    int getClientAddress(RSocketAddress &a) {
        if( !isOpened() ) return -1;

        Poco::Net::SocketAddress addr;

        try {
            if( m_socketType == SOCKET_TCP ) {
                if( m_isServer == 2 ) addr = socketTCP.peerAddress();
                else                  addr = socketTCP.address();
            } else if( m_socketType == SOCKET_UDP ) {
                addr = socketUDP.address();
            } else if( m_socketType == SOCKET_UDP_MULTICAST ) {
                addr = socketUDP_mc.address();
            }
        } catch( ... ) {
            return -1;
        }

        a.address = addr.host().toString();
        a.port = addr.port();
        a.type = m_socketType;

        return 0;
    }


    ////////////////////////////////////////////////////
    /// socket options
    ////////////////////////////////////////////////////

    /// Switch the active socket to non-blocking mode (`nb` != 0) or back to
    /// blocking mode (`nb` == 0). Poco's setBlocking() takes the opposite
    /// flag, hence the negation.
    /// @return 0, or -1 if not open or Poco threw.
    int setNonBlocking(int nb=1) {
        if( !isOpened() ) return -1;

        const bool blocking = !nb;

        try {
            if( m_socketType == SOCKET_TCP ) {
                if( m_isServer == 1 ) socketTCP_server.setBlocking(blocking);
                else                  socketTCP.setBlocking(blocking);
            } else if( m_socketType == SOCKET_UDP ) {
                socketUDP.setBlocking(blocking);
            } else {
                socketUDP_mc.setBlocking(blocking);
            }
        } catch( ... ) {
            return -1;
        }

        return 0;
    }

    /// @return 1 if the active socket is non-blocking, 0 if blocking,
    ///         -1 if not open or Poco threw.
    int getNonBlocking(void) {
        if( !isOpened() ) return -1;

        try {
            if( m_socketType == SOCKET_TCP ) {
                if( m_isServer == 1 ) return !socketTCP_server.getBlocking();
                else                  return !socketTCP.getBlocking();
            } else if( m_socketType == SOCKET_UDP ) {
                return !socketUDP.getBlocking();
            } else {
                return !socketUDP_mc.getBlocking();
            }
        } catch( ... ) {
            return -1;
        }

        return 0;
    }

    /// Set the address-reuse option on the active socket.
    /// @return 0, or -1 if not open or Poco threw.
    int setReuseAddr(int reuse=1) {
        if( !isOpened() ) return -1;

        try {
            if( m_socketType == SOCKET_TCP ) {
                if( m_isServer == 1 ) socketTCP_server.setReuseAddress(reuse);
                else                  socketTCP.setReuseAddress(reuse);
            } else if( m_socketType == SOCKET_UDP ) {
                socketUDP.setReuseAddress(reuse);
            } else {
                socketUDP_mc.setReuseAddress(reuse);
            }
        } catch( ... ) {
            return -1;
        }

        return 0;
    }

    /// @return the address-reuse option of the active socket (0/1),
    ///         -1 if not open or Poco threw.
    int getReuseAddr(void) {
        if( !isOpened() ) return -1;

        try {
            if( m_socketType == SOCKET_TCP ) {
                if( m_isServer == 1 ) return socketTCP_server.getReuseAddress();
                else                  return socketTCP.getReuseAddress();
            } else if( m_socketType == SOCKET_UDP ) {
                return socketUDP.getReuseAddress();
            } else {
                return socketUDP_mc.getReuseAddress();
            }
        } catch( ... ) {
            return -1;
        }

        return 0;
    }


    ////////////////////////////////////////////////////
    /// status
    ////////////////////////////////////////////////////

    int isOpened(void) { return m_isOpened; }

    /// True for server sockets (role 1) and accepted connections (role 2).
    bool isServer(void) { return m_isServer != 0; }

protected:
    /// Bind the active socket to `port` and, for TCP, start listening.
    /// @return 0, -2 if bind() failed, -3 if listen() failed.
    int bindAndListen(int port) {
        if( 0 != bind(port) ) return -2;
        if( m_socketType == SOCKET_TCP && 0 != listen() ) return -3;
        return 0;
    }

    int                         m_isOpened;                     ///< socket is opened or not
    int                         m_isServer;                     ///< 0 client, 1 server, 2 accepted connection
    RSocketType                 m_socketType;                   ///< socket type


    Poco::Net::ServerSocket     socketTCP_server;               ///< TCP server
    Poco::Net::StreamSocket     socketTCP;                      ///< TCP socket

    Poco::Net::DatagramSocket   socketUDP;                      ///< UDP

    Poco::Net::MulticastSocket  socketUDP_mc;                   ///< UDP multicast
};


////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////

/// Create an unopened socket backed by a freshly allocated RSocketImpl.
RSocket::RSocket()
    : m_impl(new RSocketImpl)
{
}

/// Destroy the socket. The body is intentionally empty: cleanup is left to
/// the m_impl member. NOTE(review): m_impl is presumably a smart pointer
/// (RSocketImpl::accept calls m_impl.get()); confirm it owns the impl.
/// The RSocketImpl destructor calls close().
RSocket::~RSocket()
{
}

int RSocket::startServer(int port, RSocketType t)
{
    return m_impl->startServer(port, t);
}

int RSocket::startServer(const std::string& addr, int port, RSocketType t)
{
    return m_impl->startServer(addr, port, t);
}

int RSocket::startClient(const std::string& host, int port, RSocketType t)
{
    return m_impl->startClient(host, port, t);
}

int RSocket::close(void)
{
    return m_impl->close();
}

int RSocket::send(const void* dat, int len)
{
    return m_impl->send(dat, len);
}

int RSocket::send(const std::string& msg)
{
    return m_impl->send(msg);
}


int RSocket::recv(void* dat, int len, RSocketAddress* sender)
{
    return m_impl->recv(dat, len, sender);
}

int RSocket::recv(std::string& msg, int maxLen, RSocketAddress* sender)
{
    return m_impl->recv(msg, maxLen, sender);
}

int RSocket::recvUntil(void *dat, int len, RSocketAddress* sender)
{
    return m_impl->recvUntil(dat, len, sender);
}

int RSocket::pool(int64_t timeout, int mode)
{
    return m_impl->pool(timeout, mode);
}

int RSocket::available(void)
{
    return m_impl->available();
}


int RSocket::bind(int port)
{
    return m_impl->bind(port);
}

int RSocket::listen(void)
{
    return m_impl->listen();
}

/// Accept one incoming TCP connection into `s`.
/// Returns 0 on success, -1 on failure.
int RSocket::accept(RSocket& s)
{
    return this->m_impl->accept(s);
}

int RSocket::connect(const std::string& host, int port)
{
    return m_impl->connect(host, port);
}


/// Fill `a` with the socket's local address. Returns 0, or -1 on failure.
int RSocket::getMyAddress(RSocketAddress &a)
{
    return this->m_impl->getMyAddress(a);
}

/// Fill `a` with the address reported by RSocketImpl::getClientAddress.
/// Returns 0, or -1 on failure.
int RSocket::getClientAddress(RSocketAddress &a)
{
    return this->m_impl->getClientAddress(a);
}


int RSocket::setNonBlocking(int nb)
{
    return m_impl->setNonBlocking(nb);
}

int RSocket::getNonBlocking(void)
{
    return m_impl->getNonBlocking();
}

int RSocket::setReuseAddr(int reuse)
{
    return m_impl->setReuseAddr(reuse);
}

int RSocket::getReuseAddr(void)
{
    return m_impl->getReuseAddr();
}


int RSocket::isOpened(void)
{
    return m_impl->isOpened();
}

bool RSocket::isServer(void)
{
    return m_impl->isServer();
}


} // end of namespace pi

